package cc.shacocloud.mirage.restful.cors;

import cc.shacocloud.mirage.restful.HttpRequest;
import cc.shacocloud.mirage.restful.HttpResponse;
import cc.shacocloud.mirage.restful.http.HttpHeaderMap;
import cc.shacocloud.mirage.restful.util.CorsUtils;
import cc.shacocloud.mirage.utils.LogFormatUtils;
import cc.shacocloud.mirage.utils.collection.CollUtil;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.util.AsciiString;
import io.vertx.core.Future;
import io.vertx.core.http.HttpHeaders;
import io.vertx.core.http.HttpMethod;
import lombok.extern.slf4j.Slf4j;
import org.jetbrains.annotations.Nullable;

import java.util.ArrayList;
import java.util.List;

/**
 * Default {@link CorsProcessor} implementing the CORS processing model defined by the W3C
 * recommendation.
 */
@Slf4j
public class DefaultCorsProcessor implements CorsProcessor {

    /**
     * Name of the {@code Vary} HTTP response header.
     * <p>
     * {@code Vary} tells caches which request headers were used to choose the representation of a
     * resource, so that a cached response is only reused for a future request whose listed headers
     * match, instead of requesting a fresh response from the origin server.
     * <p>
     * Description adapted from content by Mozilla Contributors, licensed under CC-BY-SA 2.5.
     */
    public static final AsciiString VARY = AsciiString.cached("Vary");

    /**
     * Processes a request according to the given CORS configuration.
     * <p>
     * The response always gets {@code Vary} entries for {@code Origin},
     * {@code Access-Control-Request-Method} and {@code Access-Control-Request-Headers} unless it
     * already lists them. This applies whether or not the request is a CORS request, because the
     * outcome of this processor depends on those request headers.
     * <p>
     * When {@code corsConfiguration} is {@code null}, simple and actual requests are not rejected.
     * They simply get no CORS headers. Pre-flight requests, however, are rejected. CORS processing
     * is also skipped when the response already contains {@code Access-Control-Allow-Origin}.
     *
     * @param corsConfiguration the CORS configuration to apply, may be {@code null}
     * @param request           the current request
     * @param response          the current response
     * @return a future completed with {@code true} if request handling should continue, or
     *         {@code false} if the request was rejected via {@link #rejectRequest(HttpResponse)}
     */
    @Override
    public Future<Boolean> processRequest(@Nullable CorsConfiguration corsConfiguration,
                                          HttpRequest request, HttpResponse response) {
        HttpHeaderMap headers = response.headers();

        // Every header below influences the CORS decision, so caches must key on all of them.
        List<String> varyHeaders = headers.getAll(VARY);
        addVaryIfAbsent(headers, varyHeaders, HttpHeaders.ORIGIN);
        addVaryIfAbsent(headers, varyHeaders, HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD);
        addVaryIfAbsent(headers, varyHeaders, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS);

        // Not a CORS request: nothing else to do.
        if (!CorsUtils.isCorsRequest(request)) {
            return Future.succeededFuture(true);
        }

        // Something upstream already produced CORS headers; do not override them.
        if (headers.get(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN) != null) {
            LogFormatUtils.traceDebug(log, traceOn -> "跳过，响应头已经包含 \"Access-Control-Allow-Origin\"");
            return Future.succeededFuture(true);
        }

        boolean preFlightRequest = CorsUtils.isPreFlightRequest(request);
        if (corsConfiguration == null) {
            // Without a configuration a pre-flight can never succeed; other requests pass through.
            if (preFlightRequest) {
                rejectRequest(response);
                return Future.succeededFuture(false);
            }
            return Future.succeededFuture(true);
        }

        return handleInternal(request, response, corsConfiguration, preFlightRequest);
    }

    /**
     * Appends {@code headerName} to the {@code Vary} response header unless one of the existing
     * {@code varyHeaders} values already lists it.
     */
    private static void addVaryIfAbsent(HttpHeaderMap headers, @Nullable List<String> varyHeaders,
                                        CharSequence headerName) {
        if (!containsVaryValue(varyHeaders, headerName)) {
            headers.add(VARY, headerName);
        }
    }

    /**
     * Returns whether any of the given {@code Vary} values lists {@code headerName}.
     * <p>
     * A single {@code Vary} value may be a comma-separated list of field names. Field names are
     * case-insensitive, so each token is trimmed and compared ignoring case.
     */
    private static boolean containsVaryValue(@Nullable List<String> varyHeaders, CharSequence headerName) {
        if (varyHeaders == null || varyHeaders.isEmpty()) {
            return false;
        }
        String expected = headerName.toString();
        for (String value : varyHeaders) {
            if (value == null) {
                continue;
            }
            for (String token : value.split(",")) {
                if (token.trim().equalsIgnoreCase(expected)) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Applies {@code corsConfiguration} to a CORS request.
     * <p>
     * It checks the origin, the method (for a pre-flight, taken from
     * {@code Access-Control-Request-Method}) and the headers. The header check is only enforced for
     * pre-flight requests. If every check passes, it writes the CORS response headers. Any failed
     * check rejects the request.
     *
     * @param request           the current request
     * @param response          the current response
     * @param corsConfiguration the configuration to apply, never {@code null}
     * @param preFlightRequest  whether the request is a CORS pre-flight request
     * @return a future completed with {@code true} if the request is allowed, {@code false} if it
     *         was rejected
     */
    protected Future<Boolean> handleInternal(HttpRequest request, HttpResponse response,
                                             CorsConfiguration corsConfiguration, boolean preFlightRequest) {
        HttpHeaderMap requestHeaderMap = request.headers();
        HttpHeaderMap responseHeaderMap = response.headers();

        String requestOrigin = requestHeaderMap.getOrigin();
        String allowOrigin = checkOrigin(corsConfiguration, requestOrigin);

        if (allowOrigin == null) {
            LogFormatUtils.traceDebug(log, traceOn -> "拒绝: 来源 '" + requestOrigin + "' 不允许");
            rejectRequest(response);
            return Future.succeededFuture(false);
        }

        HttpMethod requestMethod = getMethodToUse(request, preFlightRequest);
        List<HttpMethod> allowMethods = checkMethods(corsConfiguration, requestMethod);
        if (allowMethods == null) {
            LogFormatUtils.traceDebug(log, traceOn -> "拒绝: HTTP方法 '" + requestMethod + "' 不允许");
            rejectRequest(response);
            return Future.succeededFuture(false);
        }

        List<String> requestHeaders = getHeadersToUse(request, preFlightRequest);
        List<String> allowHeaders = checkHeaders(corsConfiguration, requestHeaders);
        if (preFlightRequest && allowHeaders == null) {
            LogFormatUtils.traceDebug(log, traceOn -> "拒绝: 请求头 [" + LogFormatUtils.formatValue(requestHeaders, !traceOn) + "] 不允许");
            rejectRequest(response);
            return Future.succeededFuture(false);
        }

        responseHeaderMap.setAccessControlAllowOrigin(allowOrigin);

        if (preFlightRequest) {
            responseHeaderMap.setAccessControlAllowMethods(allowMethods);
        }

        // allowHeaders is non-null here for pre-flights: the null case was rejected above.
        if (preFlightRequest && !allowHeaders.isEmpty()) {
            responseHeaderMap.setAccessControlAllowHeaders(allowHeaders);
        }

        if (CollUtil.isNotEmpty(corsConfiguration.getExposedHeaders())) {
            responseHeaderMap.setAccessControlExposeHeaders(corsConfiguration.getExposedHeaders());
        }

        if (Boolean.TRUE.equals(corsConfiguration.getAllowCredentials())) {
            responseHeaderMap.setAccessControlAllowCredentials(true);
        }

        if (preFlightRequest && corsConfiguration.getMaxAge() != null) {
            responseHeaderMap.setAccessControlMaxAge(corsConfiguration.getMaxAge());
        }

        return Future.succeededFuture(true);
    }

    /**
     * Called when a CORS check fails.
     * <p>
     * The default implementation sets the response status to 403, sets the HTTP status message
     * (reason phrase) to {@code "Invalid CORS request"}, and ends the response.
     *
     * @param response the response to reject
     */
    protected void rejectRequest(HttpResponse response) {
        response.setStatusCode(HttpResponseStatus.FORBIDDEN.code());
        response.setStatusMessage("Invalid CORS request");
        response.end();
    }

    /**
     * Checks the HTTP method and determines the methods to allow.
     * <p>
     * The default implementation delegates to {@link CorsConfiguration#checkHttpMethod(HttpMethod)}.
     *
     * @return the allowed methods, or {@code null} if {@code requestMethod} is not allowed
     */
    @Nullable
    protected List<HttpMethod> checkMethods(CorsConfiguration config, @Nullable HttpMethod requestMethod) {
        return config.checkHttpMethod(requestMethod);
    }

    /**
     * Checks the request origin and determines the origin to return in the response.
     * <p>
     * The default implementation delegates to {@link CorsConfiguration#checkOrigin(String)}.
     *
     * @return the origin to allow, or {@code null} if {@code requestOrigin} is not allowed
     */
    @Nullable
    protected String checkOrigin(CorsConfiguration config, @Nullable String requestOrigin) {
        return config.checkOrigin(requestOrigin);
    }

    /**
     * Returns the method to validate. For a pre-flight this is the value of
     * {@code Access-Control-Request-Method}; otherwise it is the actual request method.
     */
    @Nullable
    private HttpMethod getMethodToUse(HttpRequest request, boolean isPreFlight) {
        return (isPreFlight ? request.headers().getAccessControlRequestMethod() : request.method());
    }

    /**
     * Returns the header names to validate. For a pre-flight these come from
     * {@code Access-Control-Request-Headers}; otherwise they are all header names of the request.
     */
    private List<String> getHeadersToUse(HttpRequest request, boolean isPreFlight) {
        HttpHeaderMap headers = request.headers();
        return (isPreFlight ? headers.getAccessControlRequestHeaders() : new ArrayList<>(headers.keySet()));
    }

    /**
     * Checks the request headers and determines the headers to allow.
     * <p>
     * The default implementation delegates to {@link CorsConfiguration#checkHeaders(List)}.
     *
     * @return the allowed headers, or {@code null} if the request headers are not allowed
     */
    @Nullable
    protected List<String> checkHeaders(CorsConfiguration config, List<String> requestHeaders) {
        return config.checkHeaders(requestHeaders);
    }


}
